import torch
from torch import nn
# Softmax over dim=0: a 1-D tensor has a single axis, so all of its
# elements are normalized together and the result sums to 1.
softmax_dim0 = nn.Softmax(dim=0)
# torch.tensor with float literals yields a float32 tensor, which is what
# softmax requires (it is not defined for integer dtypes).
vector = torch.tensor([3.0, 2.0, 1.0])
print(softmax_dim0(vector))  # tensor([0.6652, 0.2447, 0.0900])

# Softmax over dim=1: for a 2-D tensor every row is normalized
# independently, so each row of the result sums to 1.
softmax_dim1 = nn.Softmax(dim=1)
matrix = torch.tensor(
    [
        [3.0, 2.0, 1.0],  # -> [0.6652, 0.2447, 0.0900]
        [3.0, 3.0, 1.0],  # -> [0.4683, 0.4683, 0.0634]
    ]
)
print(softmax_dim1(matrix))